{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "fd1ab3f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms\n",
    "import argparse\n",
    "import time\n",
    "\n",
    "# ---- Training configuration ------------------------------------------------\n",
    "# Hyper-parameters are declared argparse-style so the same code works as a script.\n",
    "parser = argparse.ArgumentParser(description='PyTorch MNIST Training')\n",
    "parser.add_argument(\n",
    "    '--batch-size', type=int, default=128, metavar='N',\n",
    "    help='input batch size for training (default: 128)',\n",
    ")\n",
    "parser.add_argument(\n",
    "    '--test-batch-size', type=int, default=128, metavar='N',\n",
    "    help='input batch size for testing (default: 128)',\n",
    ")\n",
    "parser.add_argument(\n",
    "    '--epochs', type=int, default=10, metavar='N',\n",
    "    help='number of epochs to train',\n",
    ")\n",
    "parser.add_argument(\n",
    "    '--lr', type=float, default=0.01, metavar='LR',\n",
    "    help='learning rate',\n",
    ")\n",
    "parser.add_argument(\n",
    "    '--no-cuda', action='store_true', default=False,\n",
    "    help='disables CUDA training',\n",
    ")\n",
    "parser.add_argument(\n",
    "    '--seed', type=int, default=1, metavar='S',\n",
    "    help='random seed (default: 1)',\n",
    ")\n",
    "\n",
    "# An explicit empty argument list is parsed, so every option takes its default\n",
    "# and the kernel's own command line is never read.\n",
    "args = parser.parse_args(args=[])\n",
    "\n",
    "# ---- Device selection -------------------------------------------------------\n",
    "# CUDA is used only when it has not been disabled and a device is available.\n",
    "use_cuda = (not args.no_cuda) and torch.cuda.is_available()\n",
    "if use_cuda:\n",
    "    device = torch.device(\"cuda\")\n",
    "    # Extra DataLoader options that are only passed when running on the GPU.\n",
    "    kwargs = {'num_workers': 1, 'pin_memory': True}\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "    kwargs = {}\n",
    "\n",
    "# Seed torch's global RNG before any model or shuffling loader is created.\n",
    "torch.manual_seed(args.seed)\n",
    "\n",
    "# ---- Data pipeline ----------------------------------------------------------\n",
    "# Convert images to tensors, then normalise with a fixed mean / std pair\n",
    "# (presumably the MNIST training-set statistics -- TODO confirm).\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.1307,), (0.3081,)),\n",
    "])\n",
    "\n",
    "# Only the training split requests a download; both splits share '../data'.\n",
    "trainset = datasets.MNIST('../data', train=True, download=True, transform=transform)\n",
    "testset = datasets.MNIST('../data', train=False, transform=transform)\n",
    "\n",
    "# Training batches are reshuffled every epoch; test batches keep a fixed order.\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    trainset, batch_size=args.batch_size, shuffle=True, **kwargs\n",
    ")\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    testset, batch_size=args.test_batch_size, shuffle=False, **kwargs\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4102df45",
   "metadata": {},
   "outputs": [],
   "source": [
    "#define fully connected network\n",
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.fc1 = nn.Linear(28*28, 128)\n",
    "        self.fc2 = nn.Linear(128, 64)\n",
    "        self.fc3 = nn.Linear(64, 32)\n",
    "        self.fc4 = nn.Linear(32, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.fc1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.fc2(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.fc3(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.fc4(x)\n",
    "        output = F.log_softmax(x, dim=1)\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "24bf698b",
   "metadata": {},
   "outputs": [],
   "source": [
    "#train function\n",
    "def train(args, model, device, train_loader, optimizer, epoch):\n",
    "    model.train()\n",
    "    for batch_idx, (data, target) in enumerate(train_loader):\n",
    "        data, target = data.to(device), target.to(device)\n",
    "        data = data.view(data.size(0),28*28)\n",
    "        \n",
    "        #clear gradients\n",
    "        optimizer.zero_grad()\n",
    "        \n",
    "        #compute loss\n",
    "        loss = F.cross_entropy(model(data), target)\n",
    "        \n",
    "        #get gradients and update\n",
    "        loss.backward()\n",
    "        optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d6f4a813",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cpu\n",
      "Epoch 1: 13s, trn_loss: 1.6456, trn_acc: 47.92%, test_loss: 1.6397, test_acc: 48.43%\n",
      "Epoch 2: 14s, trn_loss: 0.4936, trn_acc: 86.18%, test_loss: 0.4827, test_acc: 86.61%\n",
      "Epoch 3: 13s, trn_loss: 0.3719, trn_acc: 89.15%, test_loss: 0.3611, test_acc: 89.56%\n",
      "Epoch 4: 14s, trn_loss: 0.3176, trn_acc: 90.72%, test_loss: 0.3129, test_acc: 90.72%\n",
      "Epoch 5: 14s, trn_loss: 0.2818, trn_acc: 91.81%, test_loss: 0.2804, test_acc: 91.97%\n",
      "Epoch 6: 14s, trn_loss: 0.2554, trn_acc: 92.64%, test_loss: 0.2556, test_acc: 92.72%\n",
      "Epoch 7: 14s, trn_loss: 0.2312, trn_acc: 93.27%, test_loss: 0.2319, test_acc: 93.13%\n",
      "Epoch 8: 14s, trn_loss: 0.2107, trn_acc: 93.91%, test_loss: 0.2114, test_acc: 93.69%\n",
      "Epoch 9: 13s, trn_loss: 0.1904, trn_acc: 94.58%, test_loss: 0.1910, test_acc: 94.37%\n",
      "Epoch 10: 13s, trn_loss: 0.1736, trn_acc: 94.95%, test_loss: 0.1766, test_acc: 94.71%\n"
     ]
    }
   ],
   "source": [
    "#predict function\n",
    "def eval_test(model, device, test_loader):\n",
    "    model.eval()\n",
    "    test_loss = 0\n",
    "    correct = 0\n",
    "    with torch.no_grad():\n",
    "        for data, target in test_loader:\n",
    "            data, target = data.to(device), target.to(device)\n",
    "            data = data.view(data.size(0),28*28)\n",
    "            output = model(data)\n",
    "            test_loss += F.cross_entropy(output, target, size_average=False).item()\n",
    "            pred = output.max(1, keepdim=True)[1]\n",
    "            correct += pred.eq(target.view_as(pred)).sum().item()\n",
    "    test_loss /= len(test_loader.dataset)\n",
    "    test_accuracy = correct / len(test_loader.dataset)\n",
    "    return test_loss, test_accuracy\n",
    "\n",
    "# Entry point: train the network and report train / test metrics per epoch\n",
    "def main():\n",
    "    \"\"\"Train a fresh ``Net`` with SGD and print one summary line per epoch.\n",
    "\n",
    "    Uses the module-level ``args``, ``device``, ``train_loader`` and\n",
    "    ``test_loader``. After each epoch the model is evaluated on both loaders;\n",
    "    the reported time covers training plus both evaluations.\n",
    "    \"\"\"\n",
    "    net = Net().to(device)\n",
    "    print(device)\n",
    "    sgd = optim.SGD(net.parameters(), lr=args.lr)\n",
    "\n",
    "    for epoch in range(1, args.epochs + 1):\n",
    "        epoch_start = time.time()\n",
    "\n",
    "        # Optimise for one epoch, then score the updated weights on both splits.\n",
    "        train(args, net, device, train_loader, sgd, epoch)\n",
    "        trn_loss, trn_acc = eval_test(net, device, train_loader)\n",
    "        tst_loss, tst_acc = eval_test(net, device, test_loader)\n",
    "\n",
    "        # Whole seconds elapsed, truncated rather than rounded.\n",
    "        elapsed = int(time.time() - epoch_start)\n",
    "        segments = [\n",
    "            'Epoch ' + str(epoch) + ': ' + str(elapsed) + 's',\n",
    "            'trn_loss: {:.4f}, trn_acc: {:.2f}%'.format(trn_loss, 100. * trn_acc),\n",
    "            'test_loss: {:.4f}, test_acc: {:.2f}%'.format(tst_loss, 100. * tst_acc),\n",
    "        ]\n",
    "        print(', '.join(segments))\n",
    "\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    main()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "abf649a0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.8.1\n"
     ]
    }
   ],
   "source": [
    "# Show which PyTorch version this kernel has imported.\n",
    "print(torch.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c8c00cb",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
